#-*-Python-*-
# Config for PAIRED architecture with no adversary RNN, for mini envs.

import gin

import social_rl.adversarial_env.train_adversarial_env
import tf_agents.agents.ppo.ppo_agent

# Run function
# The `@` prefix makes the value a reference to the configurable
# train_adversarial_env.train_eval itself (not a literal and not a call), so
# the consumer of xm_train.train_eval receives that function with the bindings
# from this file applied when it is invoked.
# NOTE(review): xm_train is not imported in this file; presumably the launcher
# registers it before this config is parsed -- confirm.
xm_train.train_eval = @train_adversarial_env.train_eval

# Environment
# Environment id handed to train_eval as `env_name`. The "Mini" variant matches
# the "mini envs" scope stated in the file header.
train_adversarial_env.train_eval.env_name = 'MultiGrid-MiniAdversarial-v0'

# Adversarial params
# NOTE(review): the meanings below are inferred from the parameter names only;
# confirm against train_adversarial_env.train_eval before relying on them.
#   agents_learn_with_regret: presumably trains the agents on a regret signal.
#   non_negative_regret:      presumably clips regret at zero.
#   percent_random_episodes:  0. presumably disables randomly generated episodes.
train_adversarial_env.train_eval.agents_learn_with_regret = True
train_adversarial_env.train_eval.non_negative_regret = True
train_adversarial_env.train_eval.percent_random_episodes = 0.

# Training params
# Run length, data collection, and checkpoint/logging cadence for train_eval.
# NOTE(review): units (environment steps vs. training iterations) are defined by
# train_eval, not by this file -- confirm there before changing any interval.
train_adversarial_env.train_eval.num_train_steps = 900000
train_adversarial_env.train_eval.replay_buffer_capacity = 4001
# Same value as num_parallel_envs below; presumably one episode is collected per
# parallel environment each iteration, so keep the two in sync -- confirm.
train_adversarial_env.train_eval.collect_episodes_per_iteration = 30
# Train-state and policy checkpoints share the same interval.
train_adversarial_env.train_eval.train_checkpoint_interval = 10000
train_adversarial_env.train_eval.policy_checkpoint_interval = 10000
train_adversarial_env.train_eval.log_interval = 100
# Evaluation and summary writing share the same interval.
train_adversarial_env.train_eval.eval_interval = 500
train_adversarial_env.train_eval.summary_interval = 500
train_adversarial_env.train_eval.num_parallel_envs = 30
train_adversarial_env.train_eval.num_eval_episodes = 10
train_adversarial_env.train_eval.debug = False

# Architecture params
# Network sizes for the un-prefixed (agent) settings; the environment adversary
# has its own `adv_*` counterparts in the next section.
# NOTE(review): how each value is used (layer widths, filter counts, embedding
# sizes) is decided inside train_eval -- confirm there.
train_adversarial_env.train_eval.actor_fc_layers = (32, 32)
train_adversarial_env.train_eval.value_fc_layers = (32, 32)
train_adversarial_env.train_eval.lstm_size = (256,)
train_adversarial_env.train_eval.conv_filters = 16
train_adversarial_env.train_eval.conv_kernel = 3
train_adversarial_env.train_eval.direction_fc = 5
train_adversarial_env.train_eval.entropy_regularization = 0.

# Environment architecture params
# `adv_*` settings mirror the agent architecture section above, with their own
# values. adversary_env_rnn = False is the "no adversary RNN" choice named in
# the file header.
# NOTE(review): the exact role of each value is decided inside train_eval --
# confirm there.
train_adversarial_env.train_eval.adversary_env_rnn = False
train_adversarial_env.train_eval.adv_actor_fc_layers = (128, 64)
train_adversarial_env.train_eval.adv_value_fc_layers = (128, 64)
train_adversarial_env.train_eval.adv_conv_filters = 64
train_adversarial_env.train_eval.adv_conv_kernel = 3
train_adversarial_env.train_eval.adv_timestep_fc = 10
train_adversarial_env.train_eval.adv_entropy_regularization = 0.1

# Agent params
# Bound on the tf_agents PPOAgent class (module imported at the top of this
# file) rather than on train_eval, so it supplies discount_factor wherever a
# PPOAgent is constructed without the caller passing that argument explicitly.
# NOTE(review): presumably this covers every PPO agent train_eval builds --
# confirm that train_eval does not pass discount_factor itself.
ppo_agent.PPOAgent.discount_factor = 0.995
